
Add parameter to DiceMetric and DiceHelper classes #8774

Open
VijayVignesh1 wants to merge 9 commits into Project-MONAI:dev from VijayVignesh1:8733-per-component-dice-metric

Conversation

@VijayVignesh1 VijayVignesh1 commented Mar 13, 2026

Fixes #8733

Description

This PR adds support for connected-component-based Dice metric calculation to the existing DiceMetric and DiceHelper classes.

Changes

  • Added per_component: bool = False to both DiceMetric and DiceHelper constructors
  • Implemented compute_cc_dice method that calculates Dice scores for each connected component individually
  • Voronoi regions: Added compute_voronoi_regions_fast method for efficient connected component assignment without external cc3d dependency
  • Added input shape validation requiring 5D binary segmentation with 2 channels (background + foreground) when per_component=True
  • Updated first_ch calculation to properly exclude background channel when using per-component mode
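The pipeline the changes describe — decompose the ground-truth foreground into connected components, assign every voxel to its nearest component (a Voronoi partition), then score each component separately — can be sketched in pure NumPy. This is only an illustration of the idea on a 2D toy grid, not the PR's implementation (which uses scipy.ndimage and works on 5D tensors); all function names here are hypothetical:

```python
import numpy as np

def connected_components(mask):
    """Label 4-connected components of a 2D binary mask via BFS."""
    labels = np.zeros(mask.shape, dtype=np.int32)
    current = 0
    for start in zip(*np.nonzero(mask)):
        if labels[start]:
            continue
        current += 1
        labels[start] = current
        stack = [start]
        while stack:
            r, c = stack.pop()
            for dr, dc in ((1, 0), (-1, 0), (0, 1), (0, -1)):
                nr, nc = r + dr, c + dc
                if (0 <= nr < mask.shape[0] and 0 <= nc < mask.shape[1]
                        and mask[nr, nc] and not labels[nr, nc]):
                    labels[nr, nc] = current
                    stack.append((nr, nc))
    return labels, current

def voronoi_assignment(labels):
    """Assign every pixel the label of its nearest component pixel
    (brute force; the PR uses a distance transform for this step)."""
    coords = np.argwhere(labels > 0)
    comp = labels[labels > 0]
    out = np.zeros(labels.shape, dtype=np.int32)
    for r in range(labels.shape[0]):
        for c in range(labels.shape[1]):
            d2 = (coords[:, 0] - r) ** 2 + (coords[:, 1] - c) ** 2
            out[r, c] = comp[np.argmin(d2)]
    return out

def per_component_dice(y_pred, y):
    """One Dice score per ground-truth connected component."""
    labels, num = connected_components(y)
    regions = voronoi_assignment(labels)
    scores = []
    for k in range(1, num + 1):
        gt = labels == k
        pred = y_pred.astype(bool) & (regions == k)
        denom = gt.sum() + pred.sum()
        # Empty component and empty prediction: score 1.0,
        # mirroring ignore_empty=False semantics.
        scores.append(2.0 * (gt & pred).sum() / denom if denom else 1.0)
    return scores
```

With two separated blobs where the prediction fully covers one and half of the other, this yields per-component scores of 1.0 and 2/3 rather than a single pooled Dice.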

Reference

Types of changes

  • Non-breaking change (fix or new feature that would not break existing functionality).
  • Breaking change (fix or new feature that would cause existing functionality to change).
  • New tests added to cover the changes.
  • Integration tests passed locally by running ./runtests.sh -f -u --net --coverage.
  • Quick tests passed locally by running ./runtests.sh --quick --unittests --disttests.
  • In-line docstrings updated.
  • Documentation updated, tested make html command in the docs/ folder.

Signed-off-by: Vijay Vignesh Prasad Rao <vijayvigneshp02@gmail.com>

coderabbitai bot commented Mar 13, 2026

📝 Walkthrough

Adds a per_component mode to DiceMetric, DiceHelper, and compute_dice to compute Dice per connected component. When enabled, inputs must be 5D binary segmentations with exactly 2 channels; the ground-truth foreground is decomposed into connected components, Voronoi regions are computed, and per-component Dice scores are produced via the new DiceHelper methods compute_voronoi_regions_fast and compute_cc_dice. The per_component flag is propagated through the initializers and compute paths; DiceHelper.__call__ validates the input shape and raises ValueError on mismatches. Tests were added to validate per-component values and input-dimension checks (skipped if scipy.ndimage is unavailable).

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

🚥 Pre-merge checks | ✅ 3 | ❌ 2

❌ Failed checks (1 warning, 1 inconclusive)

  • Docstring Coverage — ⚠️ Warning: docstring coverage is 54.55%, below the required 80.00% threshold. Resolution: write docstrings for the functions that are missing them.
  • Title check — ❓ Inconclusive: the title is vague and does not convey the main change (per-component/connected-component Dice support), making it hard for reviewers to understand the primary change. Resolution: revise the title, e.g. 'Add per-component Dice metric support via connected component analysis' or 'Implement connected component-based Dice calculation'.
✅ Passed checks (3 passed)
  • Description check — ✅ Passed: the description is thorough and complete, covering the solution, implementation details, references, and all template sections; types of changes are documented with checkmarks.
  • Linked Issues check — ✅ Passed: changes implement the core requirements of #8733 — decompose ground truth into components, assign voxels via Voronoi partitioning, evaluate Dice per component, and support binary segmentation with seamless integration.
  • Out of Scope Changes check — ✅ Passed: all changes directly support per-component Dice calculation; the new methods (compute_cc_dice, compute_voronoi_regions_fast), parameter additions, validation, and tests align with linked issue #8733.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 4

🧹 Nitpick comments (3)
monai/metrics/meandice.py (1)

418-426: Wasted computation when per_component=True.

Lines 420-423 compute channel Dice, then lines 424-425 discard it and overwrite c_list. Move the branch earlier.

Proposed fix
         for b in range(y_pred.shape[0]):
-            c_list = []
-            for c in range(first_ch, n_pred_ch) if n_pred_ch > 1 else [1]:
-                x_pred = (y_pred[b, 0] == c) if (y_pred.shape[1] == 1) else y_pred[b, c].bool()
-                x = (y[b, 0] == c) if (y.shape[1] == 1) else y[b, c]
-                c_list.append(self.compute_channel(x_pred, x))
             if self.per_component:
                 c_list = [self.compute_cc_dice(y_pred=y_pred[b].unsqueeze(0), y=y[b].unsqueeze(0))]
+            else:
+                c_list = []
+                for c in range(first_ch, n_pred_ch) if n_pred_ch > 1 else [1]:
+                    x_pred = (y_pred[b, 0] == c) if (y_pred.shape[1] == 1) else y_pred[b, c].bool()
+                    x = (y[b, 0] == c) if (y.shape[1] == 1) else y[b, c]
+                    c_list.append(self.compute_channel(x_pred, x))
             data.append(torch.stack(c_list))
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@monai/metrics/meandice.py` around lines 418 - 426, The loop is doing wasted
work: it always computes per-channel Dice via compute_channel for each c and
only when self.per_component is True it discards those results and replaces
c_list with a compute_cc_dice call. Change the logic inside the for b in
range(...) loop to check self.per_component before computing channels; if
self.per_component is True, directly set c_list =
[self.compute_cc_dice(y_pred=y_pred[b].unsqueeze(0), y=y[b].unsqueeze(0))] and
skip the per-channel compute_channel loop and related x_pred/x extraction,
otherwise run the existing per-channel path that builds c_list with
compute_channel as before. Ensure references to y_pred, y, compute_channel,
compute_cc_dice, c_list and per_component are used so the branch correctly
short-circuits the expensive channel computations.
tests/metrics/test_compute_meandice.py (2)

253-276: Test data construction is hard to follow; expected value undocumented.

The lambda-walrus pattern obscures setup. Consider a helper function. Also document how 0.5120 was derived for maintainability.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/metrics/test_compute_meandice.py` around lines 253 - 276, TEST_CASE_16
uses a lambda-walrus pattern (variables y and y_pred inside TEST_CASE_16) that
makes the test data setup hard to read and omits explanation of the expected
0.5120 value; extract the tensor construction into a small descriptive helper
(e.g., build_test_case_16_tensors or make_meandice_case_16) and replace the
inline lambdas with calls to that helper, and add a short comment next to the
expected value explaining how 0.5120 was computed (e.g., describe overlapping
voxel counts and Dice formula for the two shifted cubes) so the test is readable
and the expected number is documented.
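A helper in the spirit this comment suggests might look like the following. Everything here is hypothetical (helper name, toy sizes); the point is that the expected value is derived from voxel counts in a comment rather than appearing as an unexplained constant like the PR's 0.5120:

```python
import numpy as np

def make_shifted_cubes(size=8, cube=4, shift=2):
    """Two cube-shaped binary volumes offset by `shift` voxels,
    returned as (y_pred, y)."""
    y = np.zeros((size, size, size), dtype=np.uint8)
    y_pred = np.zeros_like(y)
    y[:cube, :cube, :cube] = 1
    y_pred[shift:shift + cube, shift:shift + cube, shift:shift + cube] = 1
    return y_pred, y

def dice(a, b):
    return 2.0 * np.logical_and(a, b).sum() / (a.sum() + b.sum())

# Expected value documented from first principles:
# |y| = |y_pred| = 4**3 = 64, overlap = (4 - 2)**3 = 8,
# Dice = 2 * 8 / (64 + 64) = 0.125.
```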

337-339: Shape mismatch may obscure test intent.

Both tensors are 4D (not 5D) and have 3 channels (not 2). The spatial mismatch (144 vs 145) is irrelevant to the validation. Use matching shapes to clarify:

-            DiceMetric(per_component=True)(torch.ones([3, 3, 144, 144]), torch.ones([3, 3, 145, 145]))
+            DiceMetric(per_component=True)(torch.ones([3, 3, 64, 64]), torch.ones([3, 3, 64, 64]))
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/metrics/test_compute_meandice.py` around lines 337 - 339, The test
currently uses two 4D tensors with mismatched spatial sizes and 3 channels,
which obscures the intent to validate dimensionality; update
test_input_dimensions so both tensors have identical shapes but still 4D to
trigger the ValueError (e.g., use torch.ones([3, 2, 144, 144]) for both),
ensuring the failure comes from incorrect dimensionality for DiceMetric rather
than a spatial-size mismatch or wrong channel count.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@monai/metrics/meandice.py`:
- Around line 14-17: The module imports SciPy unconditionally causing CI
failures when SciPy is not installed; change the top-level imports to use
MONAI's optional_import pattern to import distance_transform_edt,
generate_binary_structure and label (sn_label) and expose a has_scipy flag, then
in compute_voronoi_regions_fast check has_scipy and raise a clear RuntimeError
if False; update references to
sn_label/distance_transform_edt/generate_binary_structure in the file to use the
optionally imported symbols so runtime usage is guarded.
- Line 416: The code currently sets first_ch based on a combined condition which
silently ignores include_background when per_component is True; update the logic
in the MeanDice/meandice implementation to detect the conflicting flags
(self.per_component True and self.include_background False) and emit a clear
warning (e.g., warnings.warn or using the module logger) that include_background
will be ignored when per_component is enabled, then keep the existing behavior
for first_ch (set first_ch=1) to preserve compatibility; reference the
attributes self.include_background, self.per_component and the local variable
first_ch so reviewers can locate and adjust the check and add the warning.
- Around line 300-321: The compute_voronoi_regions_fast function's docstring
lacks a Returns section and the function always returns a CPU tensor
(torch.from_numpy) even if the original input was a CUDA tensor; update the
docstring to include a Returns: description and type (torch.Tensor on same
device as input) and change the implementation to preserve input type/device:
accept numpy array or torch.Tensor for labels, record the original device and
dtype (if torch.Tensor), convert input to CPU numpy for EDT processing, then
convert the resulting voronoi numpy array back to a torch.Tensor and move it to
the original device and appropriate dtype before returning; reference
compute_voronoi_regions_fast, labels, edt_input, indices, and voronoi when
locating where to apply these changes.
- Around line 323-364: The compute_cc_dice method's docstring and
empty-ground-truth handling are incorrect: update the docstring for
compute_cc_dice to state the actual expected input shapes (e.g., tensors that
may include batch and channel dims such as (1, C, D, H, W) or
per-channel/per-item spatial tensors) and then change the empty-GT branch (the
y_idx[0].sum() == 0 case) to consult self.ignore_empty (return
torch.tensor(0.0/1.0 or skip/ignore according to class semantics) instead of
always appending 1.0/0.0), and move the inf/nan replacement logic (the two
torch.where lines that sanitize values) out of the else block so they run for
both empty and non-empty cases; refer to symbols compute_cc_dice, y_idx,
y_pred_idx, self.ignore_empty, cc_assignment, uniq/inv/hist/dice_scores to
locate and update the logic and docstring.
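The guarded-import pattern the first inline comment asks for can be sketched as below. MONAI's real monai.utils.optional_import likewise returns a (module, flag) pair; this reduced stand-in (written from the review text, not the PR diff) shows the idea of failing at call time with a clear message instead of at module import time:

```python
from importlib import import_module

def optional_import(name):
    """Simplified stand-in for monai.utils.optional_import:
    returns (module_or_None, availability_flag)."""
    try:
        return import_module(name), True
    except ImportError:
        return None, False

# Guarded import, as suggested for scipy.ndimage:
ndimage, has_scipy = optional_import("scipy.ndimage")

def compute_voronoi_regions_fast(labels):
    # Raise a clear error when the optional dependency is missing,
    # rather than crashing at import of the metrics module.
    if not has_scipy:
        raise RuntimeError("per_component=True requires scipy to be installed")
    ...  # EDT-based Voronoi assignment as in the PR
```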

---

Nitpick comments:
In `@monai/metrics/meandice.py`:
- Around line 418-426: The loop is doing wasted work: it always computes
per-channel Dice via compute_channel for each c and only when self.per_component
is True it discards those results and replaces c_list with a compute_cc_dice
call. Change the logic inside the for b in range(...) loop to check
self.per_component before computing channels; if self.per_component is True,
directly set c_list = [self.compute_cc_dice(y_pred=y_pred[b].unsqueeze(0),
y=y[b].unsqueeze(0))] and skip the per-channel compute_channel loop and related
x_pred/x extraction, otherwise run the existing per-channel path that builds
c_list with compute_channel as before. Ensure references to y_pred, y,
compute_channel, compute_cc_dice, c_list and per_component are used so the
branch correctly short-circuits the expensive channel computations.

In `@tests/metrics/test_compute_meandice.py`:
- Around line 253-276: TEST_CASE_16 uses a lambda-walrus pattern (variables y
and y_pred inside TEST_CASE_16) that makes the test data setup hard to read and
omits explanation of the expected 0.5120 value; extract the tensor construction
into a small descriptive helper (e.g., build_test_case_16_tensors or
make_meandice_case_16) and replace the inline lambdas with calls to that helper,
and add a short comment next to the expected value explaining how 0.5120 was
computed (e.g., describe overlapping voxel counts and Dice formula for the two
shifted cubes) so the test is readable and the expected number is documented.
- Around line 337-339: The test currently uses two 4D tensors with mismatched
spatial sizes and 3 channels, which obscures the intent to validate
dimensionality; update test_input_dimensions so both tensors have identical
shapes but still 4D to trigger the ValueError (e.g., use torch.ones([3, 2, 144,
144]) for both), ensuring the failure comes from incorrect dimensionality for
DiceMetric rather than a spatial-size mismatch or wrong channel count.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: f8c5c98e-0bb3-413d-9471-3bef41a45cfa

📥 Commits

Reviewing files that changed from the base of the PR and between daaedaa and 41e52c1.

📒 Files selected for processing (2)
  • monai/metrics/meandice.py
  • tests/metrics/test_compute_meandice.py

@VijayVignesh1 VijayVignesh1 marked this pull request as draft March 13, 2026 15:40
…itai - docstring issues, ignore_empty bug

Signed-off-by: Vijay Vignesh Prasad Rao <vijayvigneshp02@gmail.com>
@VijayVignesh1 VijayVignesh1 marked this pull request as ready for review March 13, 2026 20:22
@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

♻️ Duplicate comments (2)
monai/metrics/meandice.py (2)

427-427: ⚠️ Potential issue | 🟡 Minor

include_background is still silently ignored with per_component=True.

Line 427 forces foreground-only behavior without signaling it.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@monai/metrics/meandice.py` at line 427, The current assignment to first_ch in
MeanDice (meandice.py) ignores include_background when per_component is True;
change the logic so include_background is honored regardless of per_component by
setting first_ch based solely on self.include_background (e.g., first_ch = 0 if
self.include_background else 1) instead of conditioning on not
self.per_component; update any related comments/tests that assumed the previous
behavior.

318-328: ⚠️ Potential issue | 🟠 Major

Per-component Dice is not CUDA-safe.

Line 318 uses np.asarray(labels), which breaks for CUDA tensors; Line 328 returns a CPU tensor, and Line 360 then mixes devices.

Proposed fix
-    def compute_voronoi_regions_fast(self, labels, connectivity=26, sampling=None):
+    def compute_voronoi_regions_fast(self, labels, connectivity=26, sampling=None):
@@
-        x = np.asarray(labels)
+        labels_t = labels if isinstance(labels, torch.Tensor) else torch.as_tensor(labels)
+        in_device = labels_t.device
+        x = labels_t.detach().cpu().numpy()
@@
-        if num == 0:
-            return torch.zeros_like(torch.from_numpy(x), dtype=torch.int32)
+        if num == 0:
+            return torch.zeros_like(labels_t, dtype=torch.int32, device=in_device)
@@
-        return torch.from_numpy(voronoi)
+        return torch.from_numpy(voronoi).to(device=in_device, dtype=torch.int32)
#!/bin/bash
set -euo pipefail

# Verify current implementation uses numpy conversion without explicit CPU transfer
rg -n -C2 'def compute_voronoi_regions_fast|np\.asarray\(labels\)|torch\.from_numpy\(voronoi\)|compute_voronoi_regions_fast\(y_idx\[0\]\)' monai/metrics/meandice.py

# Confirm there is no explicit detach+cpu numpy conversion in this function body
python - <<'PY'
from pathlib import Path
text = Path("monai/metrics/meandice.py").read_text()
start = text.index("def compute_voronoi_regions_fast")
end = text.index("def compute_cc_dice")
chunk = text[start:end]
print("contains_detach_cpu_numpy:", ".detach().cpu().numpy()" in chunk or ".cpu().numpy()" in chunk)
PY

Also applies to: 356-361

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@monai/metrics/meandice.py` around lines 318 - 328, The function
compute_voronoi_regions_fast currently uses np.asarray(labels) and
torch.from_numpy(voronoi) which breaks for CUDA tensors; capture the input
tensor's device and dtype first (e.g., orig_device = labels.device if
isinstance(labels, torch.Tensor) else torch.device('cpu')), convert safely to
CPU numpy via labels = labels.detach().cpu().numpy() (or leave numpy arrays
unchanged), run the existing numpy logic, then convert the result back using
torch.from_numpy(voronoi).to(device=orig_device, dtype=torch.int32) or
torch.as_tensor(voronoi, device=orig_device, dtype=torch.int32) so the returned
tensor is on the same device as the input; update both
compute_voronoi_regions_fast and the similar code at the other location (lines
~356-361 / compute_cc_dice caller) to follow this pattern.
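The device-preserving round trip described above follows one reusable pattern: record the device, detach to CPU NumPy, run the NumPy-only work, then move the result back. A generic wrapper (hypothetical helper name, not part of the PR) might look like:

```python
import numpy as np
import torch

def run_numpy_op_device_safe(labels: torch.Tensor, op) -> torch.Tensor:
    """Apply a NumPy-only operation (e.g. SciPy's distance transform)
    to a tensor while preserving its original device."""
    orig_device = labels.device
    arr = labels.detach().cpu().numpy()  # safe for CPU and CUDA tensors alike
    result = op(arr)
    # Cast to a stable dtype and move back to where the input lived.
    return torch.as_tensor(result.astype(np.int32), device=orig_device)
```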
🧹 Nitpick comments (2)
tests/metrics/test_compute_meandice.py (1)

334-337: Add per-component validation tests for invalid y shape/channel.

test_input_dimensions covers only one invalid input pattern. Add cases for y not being (B, 2, D, H, W) and for y_pred/y channel mismatch.

As per coding guidelines, "Ensure new or modified definitions will be covered by existing or new unit tests."

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/metrics/test_compute_meandice.py` around lines 334 - 337, Extend the
test_input_dimensions in tests/metrics/test_compute_meandice.py to add two more
invalid-shape cases: (1) verify DiceMetric(per_component=True) raises ValueError
when y has wrong channel count (e.g., y shape not (B,2,D,H,W) such as
torch.ones([3,1,144,144]) or torch.ones([3,3,144,144]) depending on 2D/3D
expectation), and (2) verify DiceMetric(per_component=True) raises ValueError
when y_pred and y have mismatched channel counts (call
DiceMetric(per_component=True)(y_pred, y) where y_pred has 2 channels and y has
1 or vice versa). Reference the DiceMetric class and the existing
test_input_dimensions to add these assertions so coverage includes invalid y
shapes and channel mismatches.
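The extra checks these tests would exercise can be sketched as a standalone validator (hypothetical function, mirroring the (B, 2, D, H, W) contract the review describes rather than the PR's actual code):

```python
def validate_per_component_inputs(y_pred_shape, y_shape):
    """Reject inputs that are not 5D two-channel volumes, or that
    disagree between prediction and ground truth."""
    for name, shape in (("y_pred", y_pred_shape), ("y", y_shape)):
        if len(shape) != 5 or shape[1] != 2:
            raise ValueError(
                f"per_component=True requires {name} of shape "
                f"(B, 2, D, H, W), got {tuple(shape)}"
            )
    if tuple(y_pred_shape) != tuple(y_shape):
        raise ValueError(
            f"y_pred shape {tuple(y_pred_shape)} does not match "
            f"y shape {tuple(y_shape)}"
        )
```

A test would then assert ValueError for a wrong channel count on y, for a 4D input, and for mismatched y_pred/y shapes, in addition to the existing single case.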
monai/metrics/meandice.py (1)

429-437: Skip per-channel Dice work when per_component=True.

Line 431-434 computes channel Dice, then Line 436 overwrites c_list. This is unnecessary work on every batch item.

Proposed refactor
         data = []
         for b in range(y_pred.shape[0]):
+            if self.per_component:
+                data.append(self.compute_cc_dice(y_pred=y_pred[b].unsqueeze(0), y=y[b].unsqueeze(0)).unsqueeze(0))
+                continue
             c_list = []
             for c in range(first_ch, n_pred_ch) if n_pred_ch > 1 else [1]:
                 x_pred = (y_pred[b, 0] == c) if (y_pred.shape[1] == 1) else y_pred[b, c].bool()
                 x = (y[b, 0] == c) if (y.shape[1] == 1) else y[b, c]
                 c_list.append(self.compute_channel(x_pred, x))
-            if self.per_component:
-                c_list = [self.compute_cc_dice(y_pred=y_pred[b].unsqueeze(0), y=y[b].unsqueeze(0))]
             data.append(torch.stack(c_list))
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@monai/metrics/meandice.py` around lines 429 - 437, The loop currently always
computes per-channel Dice via compute_channel for each class then overwrites
c_list when self.per_component is True, causing unnecessary work; update the
logic in the batch loop (the block using variables b, c_list, first_ch,
n_pred_ch and calling compute_channel and compute_cc_dice) to short-circuit when
self.per_component is True—i.e., if self.per_component is True, skip the inner
per-channel loop entirely and directly set c_list =
[self.compute_cc_dice(y_pred=y_pred[b].unsqueeze(0), y=y[b].unsqueeze(0))];
otherwise run the existing per-channel computation using compute_channel. Ensure
you preserve behavior for n_pred_ch == 1 and that
data.append(torch.stack(c_list)) still executes.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@monai/metrics/meandice.py`:
- Around line 421-425: The per_component branch only validates y_pred; add the
same validation for y so incorrect shapes or channel counts on the ground truth
produce an immediate error. In the function/method that contains the existing
check (the block referencing self.per_component and y_pred), validate that y
also has 5 dimensions and y.shape[1] == 2 (or that y.shape matches y_pred) and
raise a ValueError with a message parallel to the existing one (e.g.,
"per_component requires 5D binary segmentation with 2 channels... Got shape
{y.shape}"). Ensure you reference the same symbol names (self.per_component,
y_pred, y) so the check runs before any computation that assumes the 5D
two-channel layout.

In `@tests/metrics/test_compute_meandice.py`:
- Around line 275-276: Remove the class-level `@unittest.skipUnless`(has_ndimage,
...) on TestComputeMeanDice and instead apply that skip only to the tests that
exercise the per_component code-path; identify and decorate the specific methods
(e.g., any test methods named test_*per_component* or those that call
compute_mean_dice(..., per_component=True) such as test_mean_dice_per_component)
with `@unittest.skipUnless`(has_ndimage, "Requires scipy.ndimage."); keep other
tests in TestComputeMeanDice unskipped so non-ndimage paths still run.

---

Duplicate comments:
In `@monai/metrics/meandice.py`:
- Line 427: The current assignment to first_ch in MeanDice (meandice.py) ignores
include_background when per_component is True; change the logic so
include_background is honored regardless of per_component by setting first_ch
based solely on self.include_background (e.g., first_ch = 0 if
self.include_background else 1) instead of conditioning on not
self.per_component; update any related comments/tests that assumed the previous
behavior.
- Around line 318-328: The function compute_voronoi_regions_fast currently uses
np.asarray(labels) and torch.from_numpy(voronoi) which breaks for CUDA tensors;
capture the input tensor's device and dtype first (e.g., orig_device =
labels.device if isinstance(labels, torch.Tensor) else torch.device('cpu')),
convert safely to CPU numpy via labels = labels.detach().cpu().numpy() (or leave
numpy arrays unchanged), run the existing numpy logic, then convert the result
back using torch.from_numpy(voronoi).to(device=orig_device, dtype=torch.int32)
or torch.as_tensor(voronoi, device=orig_device, dtype=torch.int32) so the
returned tensor is on the same device as the input; update both
compute_voronoi_regions_fast and the similar code at the other location (lines
~356-361 / compute_cc_dice caller) to follow this pattern.

---

Nitpick comments:
In `@monai/metrics/meandice.py`:
- Around line 429-437: The loop currently always computes per-channel Dice via
compute_channel for each class then overwrites c_list when self.per_component is
True, causing unnecessary work; update the logic in the batch loop (the block
using variables b, c_list, first_ch, n_pred_ch and calling compute_channel and
compute_cc_dice) to short-circuit when self.per_component is True—i.e., if
self.per_component is True, skip the inner per-channel loop entirely and
directly set c_list = [self.compute_cc_dice(y_pred=y_pred[b].unsqueeze(0),
y=y[b].unsqueeze(0))]; otherwise run the existing per-channel computation using
compute_channel. Ensure you preserve behavior for n_pred_ch == 1 and that
data.append(torch.stack(c_list)) still executes.

In `@tests/metrics/test_compute_meandice.py`:
- Around line 334-337: Extend the test_input_dimensions in
tests/metrics/test_compute_meandice.py to add two more invalid-shape cases: (1)
verify DiceMetric(per_component=True) raises ValueError when y has wrong channel
count (e.g., y shape not (B,2,D,H,W) such as torch.ones([3,1,144,144]) or
torch.ones([3,3,144,144]) depending on 2D/3D expectation), and (2) verify
DiceMetric(per_component=True) raises ValueError when y_pred and y have
mismatched channel counts (call DiceMetric(per_component=True)(y_pred, y) where
y_pred has 2 channels and y has 1 or vice versa). Reference the DiceMetric class
and the existing test_input_dimensions to add these assertions so coverage
includes invalid y shapes and channel mismatches.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 2469c914-d7e5-4549-930b-3212056a1266

📥 Commits

Reviewing files that changed from the base of the PR and between 41e52c1 and ba2e0b3.

📒 Files selected for processing (2)
  • monai/metrics/meandice.py
  • tests/metrics/test_compute_meandice.py

…eck bug

Signed-off-by: Vijay Vignesh Prasad Rao <vijayvigneshp02@gmail.com>
@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (3)
tests/metrics/test_compute_meandice.py (1)

256-272: Test data note: batches 1-4 have all-zero tensors.

Valid for testing empty GT handling (ignore_empty=False returns 1.0), but technically not proper one-hot encoding. Consider adding a comment explaining intent.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/metrics/test_compute_meandice.py` around lines 256 - 272, TEST_CASE_16
uses y and y_hat where batches 1-4 are all-zero (not proper one-hot), so clarify
intent: update the test comment near TEST_CASE_16 to state that y and y_hat
intentionally include all-zero batches to validate per_component DiceMetric
behavior with ignore_empty=False (expecting 1.0), reference the variables y,
y_hat and the test case name TEST_CASE_16; do not change data values, only add a
concise comment explaining that these batches are intentionally empty and used
to test empty-GT handling.
monai/metrics/meandice.py (2)

430-438: Wasteful computation when per_component=True.

Lines 432-435 compute per-channel Dice, but when per_component=True, line 437 replaces c_list entirely, discarding that work.

♻️ Suggested optimization
         for b in range(y_pred.shape[0]):
             c_list = []
-            for c in range(first_ch, n_pred_ch) if n_pred_ch > 1 else [1]:
-                x_pred = (y_pred[b, 0] == c) if (y_pred.shape[1] == 1) else y_pred[b, c].bool()
-                x = (y[b, 0] == c) if (y.shape[1] == 1) else y[b, c]
-                c_list.append(self.compute_channel(x_pred, x))
             if self.per_component:
                 c_list = [self.compute_cc_dice(y_pred=y_pred[b].unsqueeze(0), y=y[b].unsqueeze(0))]
+            else:
+                for c in range(first_ch, n_pred_ch) if n_pred_ch > 1 else [1]:
+                    x_pred = (y_pred[b, 0] == c) if (y_pred.shape[1] == 1) else y_pred[b, c].bool()
+                    x = (y[b, 0] == c) if (y.shape[1] == 1) else y[b, c]
+                    c_list.append(self.compute_channel(x_pred, x))
             data.append(torch.stack(c_list))
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@monai/metrics/meandice.py` around lines 430 - 438, The loop currently
computes per-channel Dice via compute_channel for every class into c_list and
then, if self.per_component is True, throws that work away by replacing c_list
with compute_cc_dice; to fix, short-circuit the per-component path: inside the
outer loop over b, check self.per_component first and only call
self.compute_cc_dice for that batch (y_pred[b].unsqueeze(0), y[b].unsqueeze(0))
to create c_list, otherwise run the existing per-channel computation using
compute_channel; this avoids wasted compute and ensures c_list is only populated
by the needed branch.

322-328: Minor: dtype inconsistency and allocation inefficiency.

Line 323 creates an intermediate tensor unnecessarily. Also, return dtype depends on platform (sn_label may return int32 or int64), but docstring promises int32.

♻️ Suggested fix
         if num == 0:
-            return torch.zeros_like(torch.from_numpy(x), dtype=torch.int32)
+            return torch.zeros(x.shape, dtype=torch.int32)
         edt_input = np.ones(cc.shape, dtype=np.uint8)
         edt_input[cc > 0] = 0
         indices = distance_transform_edt(edt_input, sampling=sampling, return_distances=False, return_indices=True)
         voronoi = cc[tuple(indices)]
-        return torch.from_numpy(voronoi)
+        return torch.from_numpy(voronoi.astype(np.int32))
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@monai/metrics/meandice.py` around lines 322 - 328, The early-return creates
an unnecessary tensor from x and the final return dtype can vary; change the
num==0 branch to directly return a torch tensor of zeros with the same shape as
cc and dtype=torch.int32 (avoid torch.from_numpy(x)). After computing voronoi =
cc[tuple(indices)], ensure voronoi is cast to a stable 32-bit integer numpy type
(e.g., voronoi = voronoi.astype(np.int32)) before converting with
torch.from_numpy so the returned tensor is always int32; update the code around
variables num, cc, indices, edt_input and voronoi in meandice.py accordingly.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@monai/metrics/meandice.py`:
- Line 428: The code currently forces first_ch = 1 when per_component is True,
silently ignoring include_background; update the logic in the MeanDice class
(where first_ch is computed) to detect the conflicting flags
(include_background=True and per_component=True) and emit a clear warning (using
warnings.warn or the module logger) stating that include_background will be
ignored in per_component mode and that first_ch is set to 1; keep the existing
behavior unless you intend to change semantics, but ensure the warning is raised
at construction or first use (e.g., in __init__ or the method computing
first_ch) so users are informed.
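The warn-but-preserve-behavior approach this comment describes is a standard warnings.warn pattern. A minimal sketch, assuming a hypothetical class name and keeping the existing first_ch semantics:

```python
import warnings

class PerComponentFlags:
    """Illustrates warning on the conflicting flag combination while
    keeping per-component mode's background-exclusion behavior."""

    def __init__(self, include_background: bool = True, per_component: bool = False):
        if per_component and include_background:
            warnings.warn(
                "include_background=True is ignored when per_component=True; "
                "the background channel is always excluded (first_ch=1).",
                UserWarning,
                stacklevel=2,
            )
        # Existing behavior preserved: per-component mode always skips background.
        self.first_ch = 1 if per_component or not include_background else 0
```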

---

Nitpick comments:
In `@monai/metrics/meandice.py`:
- Around line 430-438: The loop currently computes per-channel Dice via
compute_channel for every class into c_list and then, if self.per_component is
True, throws that work away by replacing c_list with compute_cc_dice; to fix,
short-circuit the per-component path: inside the outer loop over b, check
self.per_component first and only call self.compute_cc_dice for that batch
(y_pred[b].unsqueeze(0), y[b].unsqueeze(0)) to create c_list, otherwise run the
existing per-channel computation using compute_channel; this avoids wasted
compute and ensures c_list is only populated by the needed branch.
- Around line 322-328: The early-return creates an unnecessary tensor from x and
the final return dtype can vary; change the num==0 branch to directly return a
torch tensor of zeros with the same shape as cc and dtype=torch.int32 (avoid
torch.from_numpy(x)). After computing voronoi = cc[tuple(indices)], ensure
voronoi is cast to a stable 32-bit integer numpy type (e.g., voronoi =
voronoi.astype(np.int32)) before converting with torch.from_numpy so the
returned tensor is always int32; update the code around variables num, cc,
indices, edt_input and voronoi in meandice.py accordingly.

In `@tests/metrics/test_compute_meandice.py`:
- Around line 256-272: TEST_CASE_16 uses y and y_hat where batches 1-4 are
all-zero (not proper one-hot), so clarify intent: update the test comment near
TEST_CASE_16 to state that y and y_hat intentionally include all-zero batches to
validate per_component DiceMetric behavior with ignore_empty=False (expecting
1.0), reference the variables y, y_hat and the test case name TEST_CASE_16; do
not change data values, only add a concise comment explaining that these batches
are intentionally empty and used to test empty-GT handling.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 623f6db3-2ab1-4bdf-8c74-90860dc9678d

📥 Commits

Reviewing files that changed from the base of the PR and between ba2e0b3 and d9bfb5d.

📒 Files selected for processing (2)
  • monai/metrics/meandice.py
  • tests/metrics/test_compute_meandice.py



Development

Successfully merging this pull request may close these issues.

Feature Request: Evaluation of Semantic Segmentation Metrics on a per-component basis

1 participant